{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build detection dataset\n",
    "\n",
    "This tutorial shows how to build the detection dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# install each dependency only when it cannot be imported already\n",
    "!python -c \"import monai\" || pip install -q \"monai-weekly\"\n",
    "!python -c \"import pandas\" || pip install -q pandas\n",
    "# `yaml` is imported in the next code cell; it is provided by the PyYAML package\n",
    "!python -c \"import yaml\" || pip install -q pyyaml"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "import yaml\n",
    "import pandas as pd\n",
    "from monai.config import print_config\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load useful data\n",
    "\n",
    "As described in `readme.md`, we manually labeled 1126 frames in order to build the detection model.\n",
    "Please download the manually labeled bounding boxes from [this download link](https://developer.download.nvidia.com/assets/Clara/monai/tutorials/1126_frame_labels.zip) and uncompress the archive so that the resulting `labels` folder is saved into `label_14_tools_yolo_640_blur/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fold table: generated by `preprocess_detect_scene_and_split_fold.ipynb`,\n",
    "# which has to be run before this notebook.\n",
    "df = pd.read_csv(\"train_fold_balanced.csv\")\n",
    "\n",
    "# Root of the detection dataset; the downloaded `labels` folder lives here and\n",
    "# the matching frames are collected in a sibling `images` folder.\n",
    "dataset_dir = \"/raid/label_14_tools_yolo_640_blur/\"\n",
    "labels_dir, yolo_image_dir = (os.path.join(dataset_dir, sub) for sub in (\"labels\", \"images\"))\n",
    "\n",
    "# Extracted frames: generated by `preprocess_prepare_detection_dataset.ipynb`,\n",
    "# which also has to be run beforehand.\n",
    "images_dir = \"/raid/surg/image640_blur/\"\n",
    "\n",
    "# make sure the destination for the copied frames exists\n",
    "os.makedirs(yolo_image_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare images\n",
    "\n",
    "Copy the frames extracted from the videos into the YOLO dataset directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every label file, copy the frame with the same base name into the dataset image dir.\n",
    "for label in os.listdir(labels_dir):\n",
    "    stem, ext = os.path.splitext(label)\n",
    "    if ext != \".txt\":\n",
    "        # not a label file (e.g. hidden or OS-generated entries): no frame to copy\n",
    "        continue\n",
    "    # swap only the extension; `label.replace(\"txt\", \"jpg\")` would also rewrite\n",
    "    # a \"txt\" that happens to occur elsewhere in the file name\n",
    "    image = stem + \".jpg\"\n",
    "\n",
    "    image_src = os.path.join(images_dir, image)\n",
    "    image_dst = os.path.join(yolo_image_dir, image)\n",
    "    shutil.copy(image_src, image_dst)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare yaml files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `df` is overwritten by the merge below, so running this cell a second time would\n",
    "# merge again and produce `image_name_x` / `image_name_y` columns; fail early instead.\n",
    "if \"image_name\" in df.columns:\n",
    "    raise RuntimeError(\n",
    "        \"`df` already contains `image_name`; re-run the cell that reads \"\n",
    "        \"`train_fold_balanced.csv` before running this cell again.\"\n",
    "    )\n",
    "\n",
    "# one row per copied frame; sorted because the order returned by `os.listdir` is arbitrary\n",
    "df_samples = pd.DataFrame({\"image_name\": sorted(os.listdir(yolo_image_dir))})\n",
    "# the clip name is rebuilt from the second underscore-separated token of the file name\n",
    "df_samples[\"clip_name\"] = df_samples[\"image_name\"].apply(lambda x: \"clip_\" + x.split(\"_\")[1])\n",
    "# right join: keep every frame and attach the row of the fold table for its clip\n",
    "df = df.merge(df_samples, on=\"clip_name\", how=\"right\")\n",
    "\n",
    "# frames whose clip is missing from the fold table end up with NaN in `fold`; such rows\n",
    "# satisfy `fold != i` for every i, i.e. they are always training and never validation data\n",
    "n_unassigned = int(df[\"fold\"].isna().sum())\n",
    "if n_unassigned:\n",
    "    print(f\"WARNING: {n_unassigned} frames have no fold assignment\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_dict = {\n",
    "    \"bipolar dissector\": 0,\n",
    "    \"bipolar forceps\": 1,\n",
    "    \"cadiere forceps\": 2,\n",
    "    \"clip applier\": 3,\n",
    "    \"force bipolar\": 4,\n",
    "    \"grasping retractor\": 5,\n",
    "    \"monopolar curved scissors\": 6,\n",
    "    \"needle driver\": 7,\n",
    "    \"permanent cautery hook/spatula\": 8,\n",
    "    \"prograsp forceps\": 9,\n",
    "    \"stapler\": 10,\n",
    "    \"suction irrigator\": 11,\n",
    "    \"tip-up fenestrated grasper\": 12,\n",
    "    \"vessel sealer\": 13,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def write_image_list(list_path, image_names):\n",
    "    \"\"\"Write the full path of every image in `image_names` to `list_path`, one per line.\"\"\"\n",
    "    with open(list_path, \"w\") as f:\n",
    "        for name in image_names:\n",
    "            f.write(os.path.join(dataset_dir, \"images\", name) + \"\\n\")\n",
    "\n",
    "\n",
    "# class names ordered by their index, independent of the insertion order of `label_dict`\n",
    "class_names = sorted(label_dict, key=label_dict.get)\n",
    "\n",
    "# one train list, one val list and one yaml file per fold id 0-4\n",
    "for i in range(5):\n",
    "    # fold `i` is held out for validation, every other row is used for training\n",
    "    df_train = df[df[\"fold\"] != i]\n",
    "    df_val = df[df[\"fold\"] == i]\n",
    "\n",
    "    train_list = os.path.join(dataset_dir, \"train_fold{}.txt\".format(i))\n",
    "    val_list = os.path.join(dataset_dir, \"val_fold{}.txt\".format(i))\n",
    "    write_image_list(train_list, df_train.image_name.tolist())\n",
    "    write_image_list(val_list, df_val.image_name.tolist())\n",
    "    # number of training / validation images of this fold\n",
    "    print(len(df_train), len(df_val))\n",
    "\n",
    "    data = {\n",
    "        \"path\": dataset_dir,\n",
    "        \"train\": train_list,\n",
    "        \"val\": val_list,\n",
    "        \"nc\": len(class_names),\n",
    "        \"names\": class_names,\n",
    "    }\n",
    "\n",
    "    with open(os.path.join(dataset_dir, \"surg_14cls_fold{}.yaml\".format(i)), \"w\") as outfile:\n",
    "        yaml.dump(data, outfile, default_flow_style=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
